

% Start from a clean workspace, keeping only the number of covariates K and the
% MAP parameter vector; both must already be defined before this script runs.
clearvars('-except','K','MAP')
NL=1; % flag set here but not referenced in the code below; presumably read elsewhere (TODO confirm)


%% data (including out of sample)
load('../true_DGP/priorWSigma.mat')
data=readtable('../../Data/data_main_sample.csv');
dd=readtable('../../Data/distance_capitals.csv');

% Covariate specification, one row per covariate in the order stored in X(:,:,k):
%   {table column, lag in years (0 = current year, 1 = previous year), transform}.
% An observation is non-missing only if y and every listed covariate are finite.
ident=@(x) x;
switch K
    case 0
        xspec=cell(0,3);
    case 1
        xspec={'Y',1,@log};
    case 2
        xspec={'TinP',0,ident; 'Y',1,@log};
    case 3
        xspec={'TinP',0,ident; 'Y',1,@log; 'Trade',1,ident};
    case 4
        xspec={'TinP',0,ident; 'Y',1,@log; 'Trade',1,ident; 'yearsDem',0,ident};
    otherwise
        error('Unsupported number of covariates K=%d (supported: 0-4).',K)
end
% When any covariate is lagged, the previous-year row must exist for the country.
% This applies to K==1 as well: an empty lag lookup inside && would otherwise error.
needlag=any([xspec{:,2}]>0);

% Keep 1950 onward so that 1950 can serve as the lag year for the first period (1951).
data=data(data.year>=1950&isfinite(data.D),:);
ccodes=unique(data.ccode);
years=unique(data.year(data.year>1950));
N=length(ccodes);
T=length(years);
y=zeros(N,T); D=y; M=y; X=repmat(y,1,1,K);
xraw=zeros(1,K);
for n=1:N
    dn=data(data.ccode==ccodes(n),:);
    for t=1:T
        i0=dn.year==years(t);    % row(s) of the current year
        i1=dn.year==years(t)-1;  % row(s) of the previous year
        if ~any(i0) || (needlag && ~any(i1))
            continue % a required year is absent: leave the cell missing (M=0)
        end
        ok=isfinite(dn.y(i0));
        for k=1:K
            if ~ok
                break
            end
            col=dn.(xspec{k,1});
            if xspec{k,2}>0
                xraw(k)=col(i1);
            else
                xraw(k)=col(i0);
            end
            ok=isfinite(xraw(k));
        end
        if ok
            M(n,t)=1; % not missing
            y(n,t)=dn.y(i0);
            D(n,t)=dn.D(i0);
            for k=1:K
                fx=xspec{k,3};
                X(n,t,k)=fx(xraw(k));
            end
        end
    end
end
% Demean each covariate by its in-sample (first 50 periods) mean over observed cells;
% unobserved cells are kept at zero by the M mask.
for k=1:K
    X(:,:,k)=M.*(X(:,:,k)-sum(sum(X(:,1:50,k)))/sum(sum(M(:,1:50))));
end
clear dn t k i0 i1 ok col fx xraw xspec needlag ident

% Spatial inputs: every column of the distance table except the first (presumably
% capital-to-capital distances, per the file name), scaled by 10^(-6).
Z=10^(-6)*table2array(dd(:,2:end));
KZ=size(Z,3); % number of layers along dimension 3 (1 for the 2-D matrix read here)

% Shock standard deviations: square roots of the diagonal of Sigma (from priorWSigma.mat).
ep_sd=sqrt(diag(Sigma));

% Sparse-grid integration nodes/weights; arguments presumably follow the nwspgr
% convention (rule type, dimension, accuracy level) — confirm against nwspgr.
[sgi_nodes,sgi_weights]=nwspgr('KPN',2,8);

% Attach to each country code the first (in sorted order) of its acronyms.
cnames=cell(N,1);
for idx=1:N
    acr=unique(data.idacr(data.ccode==ccodes(idx)));
    cnames{idx}=acr{1};
end
ccodes=table(ccodes,cnames);

clear acr idx a cnames data dd n


%% parameter estimates
% MAP is read sequentially as
%   [alpha (N); theta (2); beta (2N); v (N); f (N); vs (N); xi (K); gamma (KZ)].
ofs=0;
alpha=MAP(ofs+(1:N));   ofs=ofs+N;
theta=MAP(ofs+(1:2));   ofs=ofs+2;
beta=MAP(ofs+(1:2*N));  ofs=ofs+2*N;
v=MAP(ofs+(1:N));       ofs=ofs+N;
f=MAP(ofs+(1:N));       ofs=ofs+N;
vs=MAP(ofs+(1:N));      ofs=ofs+N;
xi=MAP(ofs+(1:K));      ofs=ofs+K;
gamma=MAP(ofs+(1:KZ));
clear ofs


%% log-likelihood
% In-sample (first 50 periods) log-likelihood: the sign-flipped log_posterior value
% (log_posterior presumably returns the negative log posterior — confirm) minus the
% log prior kernel. The kernel combines Gaussian terms for alpha, theta, the two
% halves of beta (centred at b0A and b0D) and f, and inverse-gamma terms for v and vs.
gauss=@(x,mu,sd) sum(-0.5*((x-mu)/sd).^2);      % Gaussian log-kernel
invgam=@(x,s,d) sum(-(s+1)*log(x)-d*x.^(-1));   % inverse-gamma log-kernel
lpost=-log_posterior(MAP,D(:,1:50),M(:,1:50),X(:,1:50,:),Z,a0,om_a,th0,om_th,b0A,b0D,om_b,s_v,d_v,f0,om_f,s_vs,d_vs);
lprior=gauss(alpha,a0,om_a)+gauss(theta,th0,om_th)+...
    gauss(beta(1:N),b0A,om_b)+gauss(beta(N+(1:N)),b0D,om_b)+...
    invgam(v,s_v,d_v)+gauss(f,f0,om_f)+invgam(vs,s_vs,d_vs);
ll=lpost-lprior;
clear gauss invgam lpost lprior


%% one-step-ahead forecasts
% Cross-country covariance Q = diag(v.*ep_sd)*exp(-ZZ)*diag(v.*ep_sd), with
% ZZ = sum_kz gamma(kz)*Z(:,:,kz) the gamma-weighted distance matrix.
ZZ=zeros(N,N);
for kz=1:KZ
    ZZ=ZZ+Z(:,:,kz)*gamma(kz);
end
Q=diag(v.*ep_sd)*exp(-ZZ)*diag(v.*ep_sd);
Qs=tril(Q)+tril(Q,-1)'; % mirror the lower triangle so the matrix is exactly symmetric

% P is the inverse of the block-diagonal kron(eye(2),Qs); it is not used below but
% is left in the workspace.
P=kron(eye(2),Qs)\eye(2*N);

% Marginal standard deviations for the 2N rows of beta, repeated over all periods.
% The diagonal of kron(eye(2),Qs) is simply [diag(Q);diag(Q)], so there is no need to
% invert P back (a second 2N-by-2N solve that only adds rounding error).
Pd=repmat(sqrt([diag(Q);diag(Q)]),1,T);
B=repmat(beta,1,T);

% Covariate index XX(:,t) = X(:,t,:)*xi for every period (zeros when K==0).
XX=zeros(N,T);
for t=1:T
    XX(:,t)=permute(X(:,t,:),[1,3,2])*xi;
end

% For each sparse-grid node, evaluate exp(u) for the two indices:
%   u0 uses theta(1) and the first N rows of B/Pd;
%   u1 uses theta(2) and the last N rows, and subtracts f and XX.
% Then integrate the logistic transform exp(u)/(1+exp(u)) with the grid weights.
% Unobserved cells (M==0) are zeroed.
NS=length(sgi_weights);
U0=zeros(N*T,NS);
U1=zeros(N*T,NS);
for ns=1:NS
    U0(:,ns)=reshape(exp(alpha+theta(1)*(sgi_nodes(ns,1)*Pd(1:N,:)+B(1:N,:)+sgi_nodes(ns,2)*ep_sd)),N*T,1);
    U1(:,ns)=reshape(exp(alpha+theta(2)*(sgi_nodes(ns,1)*Pd(N+(1:N),:)+B(N+(1:N),:)+sgi_nodes(ns,2)*ep_sd)-f-XX),N*T,1);
end
U0=reshape(M,N*T,1).*((U0./(1+U0))*sgi_weights);
U1=reshape(M,N*T,1).*((U1./(1+U1))*sgi_weights);

% Predicted choice: 1 where the integrated value for option 1 is at least that for
% option 0. Unobserved cells have U0=U1=0 and therefore get D_hat=1, so mask with M
% before using D_hat.
D_hat=1*reshape(U1>=U0,N,T);
clear Qs


%% observed and predicted transitions into or out of democracy
% trns(n,t) = D(n,t)-D(n,t-1), and trns_h likewise from D_hat, wherever country n
% is observed in both t-1 and t. All other cells, including the first period, are 0.
both=false(N,T);
both(:,2:end)=M(:,1:end-1)==1 & M(:,2:end)==1;
dD=zeros(N,T);  dD(:,2:end)=diff(D,1,2);
dDh=zeros(N,T); dDh(:,2:end)=diff(D_hat,1,2);
trns=zeros(N,T);   trns(both)=dD(both);
trns_h=zeros(N,T); trns_h(both)=dDh(both);

% correct predictions at exact and 5-year windows
% For every observed in-sample transition (periods 2-50), count:
%   - an exact-year match of the predicted transition;
%   - a match anywhere in the window t-2..t+2.
% The window may extend into out-of-sample periods; it is clipped to [2,T] so that
% it never indexes past the last period.
correct0=0; correct5=0;
for n=1:N
    for t=2:50
        if trns(n,t)~=0
            win=max(2,t-2):min(T,t+2);
            correct0=correct0+1*(trns_h(n,t)==trns(n,t));
            correct5=correct5+1*any(trns_h(n,win)==trns(n,t));
        end
    end
end
clear both dD dDh win


%% print results
% In-sample (periods 1-50) summary: number of observations, log-likelihood, the share
% of observed choices predicted correctly, and the share of observed transitions
% predicted within 0 and 5 years. The shares are printed as fractions in [0,1],
% despite the '%' in the labels.
nobs_in=sum(sum(M(:,1:50)));
ntr_in=sum(sum(abs(trns(:,1:50))));
hit_in=sum(sum(M(:,1:50).*(D(:,1:50)==D_hat(:,1:50))));

disp(['no-learning model with ',num2str(K),' covariates:'])
disp(' ')
disp(['observations = ',num2str(nobs_in)])
disp(['log-likelihood = ',num2str(ll)])
disp(['% correctly predicted choices = ',num2str(hit_in/nobs_in)])
disp(['% correctly predicted transitions within 0 years = ',num2str(correct0/ntr_in)])
disp(['% correctly predicted transitions within 5 years = ',num2str(correct5/ntr_in)])
disp(' ')
disp(' ')
clear nobs_in ntr_in hit_in



